{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b587e6a0",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# !wget https://huggingface.co/huseinzol05/translation-ms-en-small/resolve/main/model.pb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5b183d8b",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
    "os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "707378f9",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "def load_graph(frozen_graph_filename):\n",
    "    \"\"\"\n",
    "    Read a serialized GraphDef from disk and import it into a fresh graph.\n",
    "\n",
    "    Args:\n",
    "        frozen_graph_filename: path of the binary GraphDef file.\n",
    "\n",
    "    Returns:\n",
    "        A new tf.Graph into which the parsed GraphDef was imported.\n",
    "    \"\"\"\n",
    "    parsed_def = tf.GraphDef()\n",
    "    with tf.gfile.GFile(frozen_graph_filename, 'rb') as reader:\n",
    "        serialized = reader.read()\n",
    "    parsed_def.ParseFromString(serialized)\n",
    "\n",
    "    # Import inside a dedicated graph so the current default graph is left alone.\n",
    "    imported = tf.Graph()\n",
    "    with imported.as_default():\n",
    "        tf.import_graph_def(parsed_def)\n",
    "    return imported"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0cc6dc80",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# g = load_graph('model.pb')\n",
    "# graph_def = g.as_graph_def()\n",
    "# for node in graph_def.node:\n",
    "#     n = g.get_tensor_by_name(f'{node.name}:0')\n",
    "#     if 'encdec_attention/multihead_attention/dot_product_attention' in node.name:\n",
    "#         print(node.name, n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e3904aea",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ubuntu/tf-nvidia/lib/python3.8/site-packages/tensorflow_gan/python/estimator/tpu_gan_estimator.py:42: The name tf.estimator.tpu.TPUEstimator is deprecated. Please use tf.compat.v1.estimator.tpu.TPUEstimator instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/ubuntu/tf-nvidia/lib/python3.8/site-packages/tensorflow_gan/python/estimator/tpu_gan_estimator.py:42: The name tf.estimator.tpu.TPUEstimator is deprecated. Please use tf.compat.v1.estimator.tpu.TPUEstimator instead.\n",
      "\n",
      "/home/ubuntu/tf-nvidia/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from tensor2tensor import models\n",
    "from tensor2tensor import problems\n",
    "from tensor2tensor.layers import common_layers\n",
    "from tensor2tensor.utils import trainer_lib\n",
    "from tensor2tensor.utils import t2t_model\n",
    "from tensor2tensor.utils import registry\n",
    "from tensor2tensor.utils import metrics\n",
    "from tensor2tensor.data_generators import problem\n",
    "from tensor2tensor.data_generators import text_problems\n",
    "from tensor2tensor.data_generators import translate\n",
    "from tensor2tensor.utils import registry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cc68258d",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# !wget https://huggingface.co/huseinzol05/bpe/resolve/main/ms-en.subwords\n",
    "\n",
    "import text_encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d04013b5",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Subword encoder built from the local 'ms-en.subwords' vocabulary file\n",
    "# (the commented wget line in the previous cell shows where it can be downloaded).\n",
    "encoder = text_encoder.SubwordTextEncoder('ms-en.subwords')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "35eeac95",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([197, 3728, 2744, 18569, 6057, 10436],\n",
       " [55, 227, 4311, 3891, 10901, 225, 1009])"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Encode one source and one target sentence into lists of subword ids.\n",
    "# NOTE(review): the misspelt words ('sangatttta', 'chickoon') presumably exist to\n",
    "# force words that split into several subwords - confirm.\n",
    "x = encoder.encode('saya suka ayam sangatttta')\n",
    "y = encoder.encode('i like chickoon so much')\n",
    "x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "9a923534",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['saya', 'suka', 'ayam', 'sangatttta']"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Map the subword ids of `x` back to whole tokens (uses a private helper of the encoder).\n",
    "encoder._subtoken_ids_to_tokens(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b0b9c4b8",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "@registry.register_problem\n",
    "class TRANSLATION32k(translate.TranslateProblem):\n",
    "    \"\"\"Translation problem whose inputs and targets share one subword vocabulary.\"\"\"\n",
    "\n",
    "    @property\n",
    "    def additional_training_datasets(self):\n",
    "        \"\"\"Extra training datasets for this problem; none are added here.\"\"\"\n",
    "        return list()\n",
    "\n",
    "    def feature_encoders(self, data_dir):\n",
    "        \"\"\"\n",
    "        Build the text encoders for both features.\n",
    "\n",
    "        `data_dir` is accepted but not used: the vocabulary is always read\n",
    "        from 'ms-en.subwords'. The same encoder object serves 'inputs' and 'targets'.\n",
    "        \"\"\"\n",
    "        shared_vocab = text_encoder.SubwordTextEncoder('ms-en.subwords')\n",
    "        return dict(inputs=shared_vocab, targets=shared_vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "cfb27e55",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Problem name used for the lookup below and later by Model.\n",
    "# NOTE(review): presumably the snake_case name the registry derives from the\n",
    "# TRANSLATION32k class registered above - confirm.\n",
    "PROBLEM = 'translatio_n32k'\n",
    "# NOTE: this rebinds `problem`, which an earlier cell imported from\n",
    "# tensor2tensor.data_generators.\n",
    "problem = problems.problem(PROBLEM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "635628f0",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Model:\n",
    "    \"\"\"\n",
    "    Graph wrapper around the tensor2tensor 'transformer' model.\n",
    "\n",
    "    Attributes set by the constructor:\n",
    "        X, Y: int32 placeholders declared with shape [None, None].\n",
    "        X_seq_len: per-row count of non-zero entries of X.\n",
    "        features: feature dict handed to the model.\n",
    "        hparams: hyper-parameters created for HPARAMS / DATA_DIR / PROBLEM.\n",
    "        translate_model: the instantiated model object.\n",
    "        out: result of calling the model on `features`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, HPARAMS = \"transformer_base\", DATA_DIR = 't2t/data', MODE = None):\n",
    "        \"\"\"\n",
    "        Create the placeholders and build the model graph.\n",
    "\n",
    "        Args:\n",
    "            HPARAMS: hyper-parameter set name passed to trainer_lib.create_hparams.\n",
    "            DATA_DIR: data directory passed to trainer_lib.create_hparams.\n",
    "            MODE: tf.estimator.ModeKeys value used to instantiate the model.\n",
    "                None (the default) keeps the previous behaviour, ModeKeys.TRAIN.\n",
    "        \"\"\"\n",
    "        if MODE is None:\n",
    "            MODE = tf.estimator.ModeKeys.TRAIN\n",
    "\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "\n",
    "        # Non-zero ids per row (NOTE(review): presumably 0 is the padding id - confirm).\n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype=tf.int32)\n",
    "\n",
    "        # Two trailing singleton axes are appended to each id matrix before it is used as a feature.\n",
    "        x = tf.expand_dims(tf.expand_dims(self.X, -1), -1)\n",
    "        y = tf.expand_dims(tf.expand_dims(self.Y, -1), -1)\n",
    "\n",
    "        features = {\n",
    "            \"inputs\": x,\n",
    "            \"targets\": y,\n",
    "            \"target_space_id\": tf.constant(1, dtype=tf.int32),\n",
    "        }\n",
    "        self.features = features\n",
    "\n",
    "        # PROBLEM is the module-level problem name defined in an earlier cell.\n",
    "        hparams = trainer_lib.create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)\n",
    "        self.hparams = hparams\n",
    "        translate_model = registry.model('transformer')(hparams, MODE)\n",
    "        self.translate_model = translate_model\n",
    "        self.out = translate_model(features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2d23c9f4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_17268/30799181.py:4: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_17268/30799181.py:4: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting T2TModel mode to 'train'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Setting T2TModel mode to 'train'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using variable initializer: uniform_unit_scaling\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function framework at 0x7f10166f9c10> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: LIVE_VARS_IN\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function framework at 0x7f10166f9c10> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: LIVE_VARS_IN\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: Entity <function framework at 0x7f10166f9c10> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: LIVE_VARS_IN\n",
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_25880_512.bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_25880_512.bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_25880_512.targets_bottom\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming feature 'targets' with symbol_modality_25880_512.targets_bottom\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Building model body\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function attention_bias_to_padding at 0x7f1014fbcca0> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function attention_bias_to_padding at 0x7f1014fbcca0> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: Entity <function attention_bias_to_padding at 0x7f1014fbcca0> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: \n",
      "WARNING:tensorflow:Entity <function layers at 0x7f101c152940> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: If not matching a CFG node, must be a block statement: <gast.gast.ImportFrom object at 0x7f100d300940>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Entity <function layers at 0x7f101c152940> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: If not matching a CFG node, must be a block statement: <gast.gast.ImportFrom object at 0x7f100d300940>\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: Entity <function layers at 0x7f101c152940> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: If not matching a CFG node, must be a block statement: <gast.gast.ImportFrom object at 0x7f100d300940>\n",
      "INFO:tensorflow:Transforming body output with symbol_modality_25880_512.top\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Transforming body output with symbol_modality_25880_512.top\n"
     ]
    }
   ],
   "source": [
    "# Build the placeholders and the transformer graph once; later cells read tensors from `model`.\n",
    "model = Model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f95dcea0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'transformer/body/encoder/layer_0/self_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/encoder/layer_0/self_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/encoder/layer_0/self_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/encoder/layer_0/self_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/encoder/layer_1/self_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/encoder/layer_1/self_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/encoder/layer_1/self_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/encoder/layer_1/self_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/encoder/layer_2/self_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/encoder/layer_2/self_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/encoder/layer_2/self_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/encoder/layer_2/self_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/encoder/layer_3/self_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/encoder/layer_3/self_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/encoder/layer_3/self_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/encoder/layer_3/self_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/encoder/layer_4/self_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/encoder/layer_4/self_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/encoder/layer_4/self_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/encoder/layer_4/self_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/encoder/layer_5/self_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/encoder/layer_5/self_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/encoder/layer_5/self_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/encoder/layer_5/self_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_0/self_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_0/self_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_0/self_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_0/self_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_0/encdec_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_0/encdec_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_0/encdec_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_0/encdec_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_1/self_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_1/self_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_1/self_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_1/self_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_1/encdec_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_1/encdec_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_1/encdec_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_1/encdec_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_2/self_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_2/self_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_2/self_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_2/self_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_2/encdec_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_2/encdec_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_2/encdec_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_2/encdec_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_3/self_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_3/self_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_3/self_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_3/self_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_3/encdec_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_3/encdec_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_3/encdec_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_3/encdec_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_4/self_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_4/self_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_4/self_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_4/self_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_4/encdec_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_4/encdec_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_4/encdec_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_4/encdec_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_5/self_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_5/self_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_5/self_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_5/self_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_5/encdec_attention/multihead_attention/dot_product_attention': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_5/encdec_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " 'transformer/body/decoder/layer_5/encdec_attention/multihead_attention/dot_product_attention/logits': <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_5/encdec_attention/multihead_attention/dot_product_attention/add:0' shape=(?, 8, ?, ?) dtype=float32>}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mapping from attention scope name to the tensor stored for it; each scope also has a '/logits' entry.\n",
    "model.translate_model.attention_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "649dffc7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Layer count; the next cell uses it to enumerate the decoder layers.\n",
    "model.hparams.num_hidden_layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "5b863d15",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Collect the encoder-decoder attention tensor of every decoder layer, in layer order.\n",
    "attn = [\n",
    "    model.translate_model.attention_weights[\n",
    "        \"transformer/body/decoder/layer_%i/encdec_attention/multihead_attention/dot_product_attention\" % layer\n",
    "    ]\n",
    "    for layer in range(model.hparams.num_hidden_layers)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "3d2a497a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_0/encdec_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_1/encdec_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_2/encdec_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_3/encdec_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_4/encdec_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>,\n",
       " <tf.Tensor 'transformer/parallel_0_4/transformer/transformer/body/decoder/layer_5/encdec_attention/multihead_attention/dot_product_attention/attention_weights:0' shape=(?, 8, ?, ?) dtype=float32>]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One encoder-decoder attention tensor per decoder layer.\n",
    "attn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5abc9872",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# The '/logits' entry stored alongside the layer-0 encoder-decoder attention weights.\n",
    "model.translate_model.attention_weights['transformer/body/decoder/layer_0/encdec_attention/multihead_attention/dot_product_attention/logits']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "07de1aa3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_17268/3951536174.py:1: The name tf.InteractiveSession is deprecated. Please use tf.compat.v1.InteractiveSession instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_17268/3951536174.py:1: The name tf.InteractiveSession is deprecated. Please use tf.compat.v1.InteractiveSession instead.\n",
      "\n",
      "2022-06-27 00:21:31.469195: I tensorflow/core/platform/profile_utils/cpu_utils.cc:109] CPU Frequency: 2496000000 Hz\n",
      "2022-06-27 00:21:31.469533: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x58af760 initialized for platform Host (this does not guarantee that XLA will be used). Devices:\n",
      "2022-06-27 00:21:31.469545: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version\n",
      "2022-06-27 00:21:31.470592: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1\n",
      "2022-06-27 00:21:31.473198: E tensorflow/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n",
      "2022-06-27 00:21:31.473213: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: huseincomel-desktop\n",
      "2022-06-27 00:21:31.473216: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: huseincomel-desktop\n",
      "2022-06-27 00:21:31.473254: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 470.129.6\n",
      "2022-06-27 00:21:31.473270: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 470.129.6\n",
      "2022-06-27 00:21:31.473273: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:310] kernel version seems to match DSO: 470.129.6\n"
     ]
    }
   ],
   "source": [
    "# Interactive session used by the following cells to evaluate tensors.\n",
    "sess = tf.InteractiveSession()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "c1c10ebd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_17268/1872165635.py:1: The name tf.global_variables_initializer is deprecated. Please use tf.compat.v1.global_variables_initializer instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /tmp/ipykernel_17268/1872165635.py:1: The name tf.global_variables_initializer is deprecated. Please use tf.compat.v1.global_variables_initializer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Initialise every global variable.\n",
    "# NOTE(review): no checkpoint restore is visible in this notebook, so the attention\n",
    "# fetched below presumably comes from untrained weights - confirm.\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "f03b67c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate all collected attention tensors for one (x, y) pair;\n",
    "# each id list is wrapped in an outer list, i.e. a batch of one.\n",
    "o = sess.run(attn, feed_dict = {\n",
    "    model.X: [x],\n",
    "    model.Y: [y]\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "a72e9724",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 (1, 8, 7, 6)\n",
      "1 (1, 8, 7, 6)\n",
      "2 (1, 8, 7, 6)\n",
      "3 (1, 8, 7, 6)\n",
      "4 (1, 8, 7, 6)\n",
      "5 (1, 8, 7, 6)\n"
     ]
    }
   ],
   "source": [
    "# Print the index and array shape of every fetched attention result.\n",
    "for i, layer_attention in enumerate(o):\n",
    "    print(i, layer_attention.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "e93039e2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6, 7)"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Source / target lengths in subword ids, for comparison with the attention shapes printed above.\n",
    "len(x), len(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "df922ae6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6, 7)"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Imported here because this cell is the first one to use numpy; without it a\n",
    "# fresh Restart & Run All fails with NameError.\n",
    "import numpy as np\n",
    "\n",
    "# o[-1][0] is the last fetched layer for the single batch item; average over its\n",
    "# first axis (size 8 in the shapes printed above) and transpose, giving one row per\n",
    "# source subword and one column per target subword (recorded shape: (6, 7)).\n",
    "a = np.mean(o[-1][0], axis =0).T\n",
    "a.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "255b70d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "67b737ef",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(7,)"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: averaging two rows of `a` along axis 0 keeps the shape of a single row.\n",
    "np.mean([a[0], a[0]], axis = 0).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "id": "579022d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def merged(tokens, weights):\n",
    "    i = 0\n",
    "    n_tokens = len(tokens)\n",
    "    new_paired_tokens = []\n",
    "    while i < n_tokens:\n",
    "        current_token = tokens[i]\n",
    "        current_weight = weights[i]\n",
    "        if not current_token.endswith('_'):\n",
    "            merged_token = ''\n",
    "            merged_weight = []\n",
    "            while (\n",
    "                not current_token.endswith('_')\n",
    "            ):\n",
    "                merged_token = merged_token + current_token.replace('_', '')\n",
    "                merged_weight.append(current_weight)\n",
    "                i = i + 1\n",
    "                current_token = tokens[i]\n",
    "                current_weight = weights[i]\n",
    "            merged_token = merged_token + tokens[i]\n",
    "            merged_weight.append(weights[i])\n",
    "            merged_weight = np.mean(merged_weight, axis=0)\n",
    "            new_paired_tokens.append((merged_token, merged_weight))\n",
    "            i = i + 1\n",
    "        else:\n",
    "            new_paired_tokens.append((current_token, current_weight))\n",
    "            i = i + 1\n",
    "    words = [\n",
    "        i[0].replace('▁', '') for i in new_paired_tokens\n",
    "    ]\n",
    "    weights = [i[1] for i in new_paired_tokens]\n",
    "    return words, np.array(weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "id": "8868e807",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look up the subword string of every source id, then collapse the subword rows of `a` into word rows.\n",
    "concatenated = list(map(encoder._subtoken_id_to_subtoken_string, x))\n",
    "c_x, a_x = merged(concatenated, a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "72c1cf18",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, 7)"
      ]
     },
     "execution_count": 109,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rows are now merged source words rather than source subwords.\n",
    "a_x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "id": "653a9938",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same merge along the target axis: transpose so that target subwords index the rows.\n",
    "concatenated = list(map(encoder._subtoken_id_to_subtoken_string, y))\n",
    "c_y, a_y = merged(concatenated, a_x.T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "id": "fb6666dc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.18293898, 0.20045616, 0.19267084, 0.19390467, 0.20172535],\n",
       "       [0.12858678, 0.14654504, 0.1563607 , 0.17914012, 0.15539902],\n",
       "       [0.1781643 , 0.158537  , 0.1996336 , 0.204992  , 0.18645568],\n",
       "       [0.17010331, 0.1648206 , 0.15044494, 0.1406544 , 0.15213999]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 114,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Transposed back so that rows follow the source words and columns the target words.\n",
    "a_y.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "id": "69fb182c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, 5)"
      ]
     },
     "execution_count": 112,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Word counts after merging, for comparison with the shape of the matrix shown above.\n",
    "len(c_x), len(c_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75a03cd5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tf1",
   "language": "python",
   "name": "tf1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
